
Add SHy model and DiagnosisPrediction task #991

Open
Lilin-Huang wants to merge 3 commits into sunlabuiuc:master from Lilin-Huang:add-shy-model

Conversation


Lilin-Huang commented Apr 15, 2026

Contributor

Type of Contribution

Model + Task (Option 4)

Paper

Leisheng Yu, Yanxiao Cai, Minxing Zhang, and Xia Hu.
Self-Explaining Hypergraph Neural Networks for Diagnosis Prediction.
Proceedings of Machine Learning Research (CHIL), 2025.
https://arxiv.org/abs/2502.10689

Description

Implement SHy (Self-Explaining Hypergraph Neural Network) for diagnosis prediction. The model builds a patient hypergraph from diagnosis codes, runs UniGIN message passing, extracts K temporal phenotype sub-hypergraphs via Gumbel-Softmax sampling, aggregates each phenotype with a GRU + attention, and predicts next-visit diagnoses. A multi-objective loss combines prediction BCE, fidelity (reconstruction), distinctness (phenotype overlap penalty), and alpha diversity.
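To make the phenotype-extraction step concrete, here is a minimal sketch of straight-through Gumbel-Softmax sampling over hyperedges. `sample_edge_mask` and the 2-way keep/drop construction are illustrative assumptions, not the PR's actual code:

```python
import torch
import torch.nn.functional as F

def sample_edge_mask(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Sample a hard keep/drop mask per hyperedge with straight-through grads.

    Each logit parameterizes a 2-way (keep vs. drop) categorical; Gumbel
    noise turns the argmax into a sample, and hard=True keeps the soft
    probabilities in the backward pass (straight-through estimator).
    """
    two_way = torch.stack([logits, torch.zeros_like(logits)], dim=-1)
    sample = F.gumbel_softmax(two_way, tau=tau, hard=True)
    return sample[..., 0]  # 1.0 where the edge is kept, 0.0 where dropped

# One logit per hyperedge of the patient hypergraph; lower tau = sharper.
edge_logits = torch.randn(8, requires_grad=True)
mask = sample_edge_mask(edge_logits, tau=0.5)  # values in {0.0, 1.0}
```

Sampling K such masks yields K differentiable sub-hypergraphs, one per temporal phenotype.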

Also adds DiagnosisPredictionMIMIC3 and DiagnosisPredictionMIMIC4 standalone task classes that extract per-visit diagnosis histories from MIMIC-III/IV.

The example scripts run ablation studies over 4 axes: number of temporal phenotypes (K), HGNN layers, loss components, and Gumbel-Softmax temperature (novel extension).

File Guide

| File | Description |
| --- | --- |
| `pyhealth/models/shy.py` | SHy model implementation |
| `pyhealth/tasks/diagnosis_prediction.py` | MIMIC-III/IV diagnosis prediction tasks |
| `pyhealth/models/__init__.py` | Register SHy import |
| `pyhealth/tasks/__init__.py` | Register DiagnosisPrediction imports |
| `tests/core/test_shy.py` | 15 unit tests with synthetic data |
| `examples/mimic3_diagnosis_prediction_shy.py` | MIMIC-III ablation study script with results |
| `examples/mimic4_diagnosis_prediction_shy.py` | MIMIC-IV ablation study script with results |
| `docs/api/models/pyhealth.models.SHy.rst` | Model RST documentation |
| `docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst` | Task RST documentation |
| `docs/api/models.rst` | Model index update |
| `docs/api/tasks.rst` | Task index update |

Ablation Results (MIMIC-III dev=True, 1000 patients, 50 epochs)

| Config | Jaccard | F1 | PR-AUC | ROC-AUC |
| --- | --- | --- | --- | --- |
| K=1 | 0.0339 | 0.0652 | 0.1732 | 0.7240 |
| K=3 | 0.0401 | 0.0762 | 0.1294 | 0.6905 |
| K=5 | 0.0402 | 0.0766 | 0.1533 | 0.7126 |
| HGNN=0 | 0.0436 | 0.0827 | 0.1517 | 0.7067 |
| HGNN=1 | 0.0413 | 0.0787 | 0.1398 | 0.6997 |
| HGNN=2 | 0.0400 | 0.0759 | 0.1352 | 0.7142 |
| no auxiliary loss | 0.0426 | 0.0808 | 0.1671 | 0.7134 |
| no fidelity | 0.0420 | 0.0799 | 0.1422 | 0.6990 |
| no distinct | 0.0390 | 0.0743 | 0.1459 | 0.6905 |
| no alpha | 0.0408 | 0.0776 | 0.1429 | 0.6917 |
| full (all loss) | 0.0347 | 0.0666 | 0.1389 | 0.6881 |
| temp=0.5 | 0.0408 | 0.0778 | 0.1265 | 0.7095 |
| temp=1.0 | 0.0397 | 0.0757 | 0.1354 | 0.6961 |
| temp=2.0 | 0.0411 | 0.0780 | 0.1431 | 0.6948 |

Ablation Results (MIMIC-IV dev=True, 1000 patients, 5 epochs)

| Config | Jaccard | F1 | PR-AUC | ROC-AUC |
| --- | --- | --- | --- | --- |
| K=1 | 0.0083 | 0.0163 | 0.1068 | 0.8590 |
| K=3 | 0.0075 | 0.0149 | 0.1432 | 0.8694 |
| K=5 | 0.0079 | 0.0157 | 0.0989 | 0.8576 |
| HGNN=0 | 0.0079 | 0.0156 | 0.1277 | 0.8697 |
| HGNN=1 | 0.0079 | 0.0156 | 0.1323 | 0.8699 |
| HGNN=2 | 0.0082 | 0.0162 | 0.1081 | 0.8558 |
| no auxiliary loss | 0.0081 | 0.0160 | 0.1199 | 0.8610 |
| no fidelity | 0.0077 | 0.0153 | 0.1344 | 0.8628 |
| no distinct | 0.0084 | 0.0166 | 0.1242 | 0.8583 |
| no alpha | 0.0082 | 0.0162 | 0.1402 | 0.8685 |
| full (all loss) | 0.0082 | 0.0162 | 0.1134 | 0.8601 |
| temp=0.5 | 0.0080 | 0.0159 | 0.1272 | 0.8678 |
| temp=1.0 | 0.0080 | 0.0158 | 0.1145 | 0.8592 |
| temp=2.0 | 0.0085 | 0.0168 | 0.1450 | 0.8691 |

Implement Self-Explaining Hypergraph Neural Network (SHy) for
diagnosis prediction, with MIMIC-III and MIMIC-IV task classes.

New files:
- pyhealth/models/shy.py: SHy model implementation
- pyhealth/tasks/diagnosis_prediction.py: DiagnosisPrediction tasks
- tests/core/test_shy.py: 12 unit tests with synthetic data
- examples/mimic4_diagnosis_prediction_shy.py: ablation study script
- docs/api/models/pyhealth.models.SHy.rst: model documentation
- docs/api/tasks/pyhealth.tasks.DiagnosisPrediction.rst: task documentation

Paper: Yu et al., "Self-Explaining Hypergraph Neural Networks for
Diagnosis Prediction", CHIL 2025.
@joshuasteier
Collaborator

Hi @Lilin-Huang. SHy is a non-trivial model to implement (hypergraph message passing, Gumbel-Softmax phenotype extraction, multi-objective loss), and you have it all in working form with real ablation numbers from MIMIC-IV. The BaseModel and BaseTask integration is correct, the __init__.py changes are purely additive, and the ablation tables in the PR description are the kind of honest reproducibility reporting we want to see. A few things worth addressing.

Scalability: forward is sample-by-sample.

```python
for i in range(batch_size):
    H = self._build_incidence_matrix(codes_batch[i]).to(self.device)
    ...
    tp_mats, tp_embs = self._encode_patient(X, H)
```

Each sample gets its own incidence matrix, HGNN pass, K extractors, and decoder call inside a Python loop. This will not scale beyond the 1000-patient dev subset used in your ablations. Per-patient hypergraphs make full batching nontrivial, but block-diagonal stacking (a standard trick from PyTorch Geometric) would let you batch the HGNN and phenotype extraction.

Not a merge blocker, but please add a note in the class docstring documenting the expected scalability (e.g., "This implementation processes samples sequentially in forward. For large batches or datasets, consider batching via block-diagonal hypergraph construction").
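For reference, a minimal sketch of the block-diagonal trick (the matrix shapes and `batch_idx` bookkeeping are illustrative, not the PR's code):

```python
import torch

# Per-patient incidence matrices (codes x visits); sizes differ per patient.
H_list = [torch.randint(0, 2, (4, 3)).float(),
          torch.randint(0, 2, (6, 2)).float()]

# Block-diagonal stacking: patients never share node or edge indices, so a
# single HGNN pass over H_big is equivalent to independent per-patient passes.
H_big = torch.block_diag(*H_list)  # shape (4+6, 3+2) = (10, 5)

# A batch index maps each node row back to its patient for un-stacking.
batch_idx = torch.cat([torch.full((H.shape[0],), i)
                       for i, H in enumerate(H_list)])
```

This is the same mini-batching scheme PyTorch Geometric uses for graphs of varying size.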

logit in the output dict is actually probabilities.

```python
pred = torch.sigmoid(self.predict(combined))
...
return {"loss": loss, "y_prob": pred, "y_true": y_true, "logit": pred}
```

pred is post-sigmoid. Other PyHealth models return pre-sigmoid logits in the logit key. This will confuse downstream code that expects to apply sigmoid itself, for example calibration or temperature scaling. Could you either return the pre-sigmoid linear output in logit, or document in the docstring that logit is probability-valued for this model?
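A sketch of the first option, with a toy linear head standing in for the model's classifier (`predict` and `combined` here are stand-ins, not the PR's attributes):

```python
import torch
import torch.nn as nn

predict = nn.Linear(8, 5)     # stand-in for the model's classifier head
combined = torch.randn(2, 8)  # stand-in for the aggregated phenotype embedding

logits = predict(combined)           # keep the pre-sigmoid linear output...
y_prob = torch.sigmoid(logits)       # ...and derive probabilities from it
out = {"y_prob": y_prob, "logit": logits}  # "logit" stays pre-sigmoid

# Downstream calibration code can now apply sigmoid itself:
assert torch.allclose(torch.sigmoid(out["logit"]), out["y_prob"])
```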

Use ValueError instead of assert for argument validation.

```python
assert len(self.label_keys) == 1, "SHy supports exactly one label key (multilabel)"
assert len(self.feature_keys) == 1, "SHy expects exactly one feature key (nested_sequence)"
```

Asserts are stripped when Python is run with -O. For argument validation that should always run, raise ValueError instead. Small fix, two lines each.

Empty-history samples contribute to prediction loss from zero-vector input.

```python
if H.sum() == 0:
    zero = torch.zeros(self.num_tp, self.hidden_dim, device=self.device)
    latent_list.append(zero)
    continue
```

When a patient has an all-zero incidence matrix (possible if all codes got clipped or visits are empty), the classifier still runs on the zero vector and the loss still counts. The task filters to len(visits) >= 2, so this should be rare in practice, but if it does happen the model learns to predict whatever the dataset average is for that position. Could you either skip these samples in the loss (exclude their indices from prediction BCE), or add a comment explaining the trade-off?
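The skip-in-loss option can be as small as a boolean mask over the batch; a sketch under assumed shapes (the `nonempty` mask would come from the `H.sum() == 0` check):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 6)                          # 4 patients, 6 codes
y_true = torch.randint(0, 2, (4, 6)).float()
nonempty = torch.tensor([True, True, False, True])  # False = empty history

# Per-sample BCE, averaged only over patients with a non-empty history,
# so zero-vector inputs contribute nothing to the gradient.
per_sample = F.binary_cross_entropy_with_logits(
    logits, y_true, reduction="none").mean(dim=1)
loss = per_sample[nonempty].mean()
```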

Fidelity loss does not handle class imbalance.

Your prediction loss correctly reweights positives:

```python
num_pos = y_true.sum(dim=1, keepdim=True).clamp(min=1)
num_neg = (y_true.shape[1] - num_pos).clamp(min=1)
pos_weight = (num_neg / num_pos).expand_as(y_true)
```

But the fidelity loss is plain BCE on the incidence matrix:

```python
F.binary_cross_entropy(r.clamp(1e-9, 1 - 1e-9), h.float())
```

The incidence matrix is typically very sparse (most codes are not in a visit). Without reweighting, fidelity encourages predicting all zeros, which is the uninformative solution. Your ablation shows fidelity has small positive weight (0.1) so this may be contained, but worth applying the same pos_weight treatment or commenting on why raw BCE is intended.
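A sketch of applying the same reweighting to fidelity. Note it switches to `binary_cross_entropy_with_logits`, since `pos_weight` is only available in the logits variant; this assumes a pre-sigmoid reconstruction is available:

```python
import torch
import torch.nn.functional as F

h = (torch.rand(30, 10) < 0.05).float()  # sparse incidence matrix (~5% ones)
r_logits = torch.randn(30, 10)           # pre-sigmoid reconstruction

# Same positive reweighting as the prediction loss, applied matrix-wide.
num_pos = h.sum().clamp(min=1)
num_neg = (torch.tensor(float(h.numel())) - num_pos).clamp(min=1)

# pos_weight upweights the rare 1-entries so "all zeros" stops being optimal.
fidelity = F.binary_cross_entropy_with_logits(
    r_logits, h, pos_weight=num_neg / num_pos)
```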

Tests verify it runs but do not verify the core model behavior.

The 12 tests cover shapes, output keys, backward, probability range, and a few hyperparameter variations. None of them check:

  1. That for num_tp=K, the model actually produces K distinct phenotype sub-hypergraphs (e.g., len(tp_list[0]) == K inside the forward path).
  2. That different configurations produce different outputs (for example, that num_tp=1 and num_tp=5 give measurably different losses on the same input).
  3. That the add_ratio false-negative addition changes the incidence matrix in a non-trivial way.

These would strengthen the PR by making it harder for future refactors to silently break the paper's core mechanism. A test that inspects the shape and sparsity of what PhenotypeExtractor produces would be a good start.
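To illustrate the shape of such a test, here is a self-contained sketch against a toy extractor (`ToyPhenotypeExtractor` is a stand-in; the real test would import the PR's extractor class):

```python
import torch

class ToyPhenotypeExtractor(torch.nn.Module):
    """Stand-in for the PR's extractor: emits K sub-hypergraph masks."""

    def __init__(self, num_tp: int, num_edges: int):
        super().__init__()
        self.logits = torch.nn.Parameter(torch.randn(num_tp, num_edges))

    def forward(self, H):
        masks = (self.logits > 0).float()           # K binary edge masks
        return [H * m.unsqueeze(0) for m in masks]  # K sub-hypergraphs

def test_extractor_emits_k_sub_hypergraphs():
    H = (torch.rand(12, 5) < 0.3).float()
    subs = ToyPhenotypeExtractor(num_tp=3, num_edges=5)(H)
    assert len(subs) == 3  # exactly K phenotypes
    for sub in subs:
        assert sub.shape == H.shape
        # A sub-hypergraph may only drop entries, never invent them.
        assert ((sub == 0) | (sub == H)).all()
```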

Teacher forcing in the decoder during evaluation.

```python
# Teacher forcing
prev_codes = H.T[t]
```

The decoder uses ground-truth codes from visit t as input for predicting visit t+1, both during training and evaluation. During eval this means the fidelity reconstruction is computed with ground-truth inputs, which inflates the reconstruction quality. For the paper's reported fidelity numbers this is fine (they are training-time regularization metrics), but if anyone looks at eval fidelity they will over-interpret it. Could you gate teacher forcing on self.training, or add a docstring note?
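The gating can be a one-line change; a toy decoder sketch (the `ToyDecoder` structure is assumed, only the `self.training` gate is the point):

```python
import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """Sketch of gating teacher forcing on self.training."""

    def __init__(self, num_codes: int):
        super().__init__()
        self.step = nn.Linear(num_codes, num_codes)

    def forward(self, H):  # H: codes x visits incidence matrix
        prev, recon = H.T[0], []
        for t in range(1, H.shape[1]):
            pred = torch.sigmoid(self.step(prev))
            recon.append(pred)
            # Ground-truth input only while training; at eval time, feed
            # back the model's own thresholded prediction instead.
            prev = H.T[t] if self.training else (pred > 0.5).float()
        return torch.stack(recon)
```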

Smaller items

  1. The docstring example for SHy:

     ```python
     >>> out = model(**batch)
     >>> out["loss"]
     ```

     shows no expected output. Consider either removing the last line or adding a `torch.Size(...)` or `tensor(...)` to make it a valid doctest.

  2. `_build_incidence_matrix` uses a Python for-loop over visits. `scatter_` on a flattened index tensor would vectorize this. Minor perf.

  3. `num_edges = E.max().item() + 1` in `UniGINConv.forward` crashes if `E` is empty. The outer `if H.sum() == 0: continue` in `forward` guards against this, but a defensive check in the layer would make it safer for direct use.

  4. The ablation table in the PR is valuable but would be more useful if you also pasted it (or a reduced version) into the example script's top docstring, so users can see the expected output without running the full MIMIC-IV pipeline.

  5. "pred only" in the ablation table is confusing as a label. Could you rename to "no auxiliary loss" or similar for clarity?
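On item 2, the vectorized incidence build could look like this (a sketch under the assumption that visits arrive as lists of code indices; the function name is illustrative):

```python
import torch

def build_incidence_matrix(visits, num_codes):
    """Vectorized incidence matrix: one row per code, one column per visit.

    `visits` is a list of code-index lists, e.g. [[0, 3], [2]]; the loop
    version this replaces sets entries visit-by-visit.
    """
    rows = torch.tensor([c for codes in visits for c in codes])
    cols = torch.tensor([t for t, codes in enumerate(visits) for _ in codes])
    H = torch.zeros(num_codes, len(visits))
    # One scatter_ over the flattened tensor instead of a Python loop;
    # row-major flat index of (r, c) is r * num_visits + c.
    H.view(-1).scatter_(0, rows * len(visits) + cols, 1.0)
    return H
```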
